import argparse
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
from itertools import combinations
import cv2
from scipy.signal import find_peaks

# Names of all body pose landmarks (PREDEFINED). process_csv assigns the CSV
# columns from this list, so the order here must match the CSV column order.
# NOTE(review): the names look like the 33 MediaPipe Pose landmarks -- confirm.
body_pose_landmarks = [
    # head
    "nose",
    "left eye inner", "left eye", "left eye outer",
    "right eye inner", "right eye", "right eye outer",
    "left ear", "right ear",
    "mouth left", "mouth right",
    # arms
    "left shoulder", "right shoulder",
    "left elbow", "right elbow",
    "left wrist", "right wrist",
    # hands
    "left pinky", "right pinky",
    "left index", "right index",
    "left thumb", "right thumb",
    # legs
    "left hip", "right hip",
    "left knee", "right knee",
    "left ankle", "right ankle",
    # feet
    "left heel", "right heel",
    "left foot index", "right foot index",
]

# Landmarks that are dropped from the data frame and left out of the angle
# combinations (eyes, ears, mouth, pinkies, thumbs and heels).
remove_list = [
    "left eye", "left eye inner", "left eye outer", "left ear",
    "right eye", "right eye inner", "right eye outer", "right ear",
    "mouth left", "mouth right",
    "left pinky", "right pinky",
    "left thumb", "right thumb",
    "left heel", "right heel",
]


# Function to calculate angle between two lines
def calculate_angle(point1, point2, point3):
    """Return an angle in degrees for the lines point1-point2 and point1-point3.

    The value is ``atan(numerator / denominator)`` where the numerator is the
    2-D cross product of the vectors ``point1 -> point2`` and
    ``point1 -> point3`` and the denominator is the negated dot product of the
    same vectors. Negative results are shifted by 180 so the returned angle
    lies in ``[0, 180)``.

    Args:
        point1: ``(x, y)`` point shared by both lines.
        point2: ``(x, y)`` second point of the first line.
        point3: ``(x, y)`` second point of the second line.

    Returns:
        ``0`` if any point equals ``(0, 0)`` (presumably the marker for an
        undetected landmark -- confirm with the CSV producer), ``90.0`` if the
        denominator is zero, otherwise the angle in degrees.
    """
    if point1 == (0, 0) or point2 == (0, 0) or point3 == (0, 0):
        return 0

    x1, y1 = point1[0], point1[1]
    x2, y2 = point2[0], point2[1]
    x3, y3 = point3[0], point3[1]

    numerator = y2 * (x1 - x3) + y1 * (x3 - x2) + y3 * (x2 - x1)
    denominator = (x2 - x1) * (x1 - x3) + (y2 - y1) * (y1 - y3)

    # Test for a zero denominator explicitly instead of relying on a bare
    # ``except``: NumPy scalars do not raise ZeroDivisionError (0/0 silently
    # becomes NaN), and a bare ``except`` would also hide unrelated errors.
    if denominator == 0:
        return 90.0

    ang = math.atan(numerator / denominator)
    ang = ang * 180 / math.pi
    if ang < 0:
        ang = 180 + ang
    return ang

# Function to process CSV file and extract features
def process_csv(input_csv):
    """Load a landmark CSV and keep the X/Y/Z columns of the retained landmarks.

    The CSV is expected to hold four columns (X, Y, Z, V) per entry of
    ``body_pose_landmarks``, in that order, followed by one file-path column.
    Its header is overwritten with generated names.

    Args:
        input_csv: Path (or any object accepted by ``pandas.read_csv``) of the
            CSV file.

    Returns:
        A new DataFrame with ``<landmark>_X/_Y/_Z`` columns for every landmark
        not in ``remove_list``, the ``file_path`` column, and an integer
        ``frame_idx`` column. Rows are sorted by ``frame_idx`` and re-indexed
        from 0.

    Raises:
        ValueError: If the number of CSV columns does not match the generated
            names (raised by pandas), or if a file name does not start with an
            integer followed by ``_``.
    """
    df = pd.read_csv(input_csv)

    col_name = []
    for landmark in body_pose_landmarks:
        col_name += [landmark + '_X', landmark + '_Y', landmark + '_Z', landmark + '_V']
    df.columns = col_name + ['file_path']

    # Select the wanted columns by exact name. This drops every '_V' column
    # and every column of a landmark in remove_list without regex/substring
    # matching, which could hit unintended columns if a name changes.
    dropped = set(remove_list)
    keep = [landmark + axis
            for landmark in body_pose_landmarks if landmark not in dropped
            for axis in ('_X', '_Y', '_Z')]

    # .copy() makes the result independent of the frame read from disk, so the
    # column assignment below cannot trigger SettingWithCopyWarning.
    df = df[keep + ['file_path']].copy()

    # The frame index is the integer before the first '_' of the file name;
    # os.path.basename copes with the platform's path separator.
    df['frame_idx'] = df['file_path'].apply(
        lambda path: int(os.path.basename(path).split('_')[0]))
    df = df.sort_values(by=['frame_idx']).reset_index(drop=True)

    return df

# Function to calculate angles and create feature dataframe
def make_feature(df):
    """Build one angle feature per triple of retained landmarks.

    For every 3-combination of the landmarks that are not in ``remove_list``
    (in ``itertools.combinations`` order) the angle returned by
    ``calculate_angle`` is computed for each row from the landmarks'
    ``_X``/``_Y`` columns.

    Args:
        df: DataFrame with ``<landmark>_X`` and ``<landmark>_Y`` columns for
            every retained landmark, plus a ``file_path`` column.

    Returns:
        A DataFrame sharing ``df``'s index with columns ``f1`` ... ``fN`` (one
        per landmark triple) followed by ``file_path``.
    """
    dropped = set(remove_list)
    kept_landmarks = [name for name in body_pose_landmarks if name not in dropped]

    # Build each landmark's (x, y) tuples once; every landmark takes part in
    # many triples, so this avoids re-reading the same cells for each of them.
    points = {
        name: list(zip(df[name + '_X'].tolist(), df[name + '_Y'].tolist()))
        for name in kept_landmarks
    }

    # Collect all columns first and create the frame in one go: inserting
    # hundreds of columns one at a time fragments the DataFrame.
    columns = {}
    for number, (first, second, third) in enumerate(combinations(kept_landmarks, 3), start=1):
        columns['f' + str(number)] = [
            calculate_angle(p1, p2, p3)
            for p1, p2, p3 in zip(points[first], points[second], points[third])
        ]

    feature_df = pd.DataFrame(columns, index=df.index)
    feature_df['file_path'] = df['file_path']

    return feature_df

# Function to calculate, for each frame, the mean of the per-feature standard deviations over a window of neighbouring frames
def calculate_std(feature_df, n=3):
    """Add a ``std`` column: mean per-feature standard deviation around each row.

    For row ``i`` the window covers rows ``i - n`` to ``i + n - 1`` (inclusive,
    clipped to the frame), i.e. rows ``i-3 .. i+2`` for the default ``n=3``.
    The population standard deviation (``ddof=0``) of every numeric column is
    taken over that window and the results are averaged into one value.

    Args:
        feature_df: DataFrame of numeric feature columns; non-numeric columns
            such as ``file_path`` are ignored. It is modified in place.
        n: Half-width of the window, must be at least 1.

    Returns:
        ``feature_df`` itself, with the ``std`` column added or replaced.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1")

    # Only numeric feature columns take part: a string column makes the
    # reduction fail on recent pandas. A 'std' column left over from an
    # earlier call is excluded so the function can be re-run safely.
    numeric = feature_df.drop(columns=['std'], errors='ignore').select_dtypes(include='number')

    std_list = []
    for i in range(len(numeric)):
        # Positional slicing, so the result does not depend on index labels.
        window = numeric.iloc[max(i - n, 0):i + n]
        std_list.append(window.std(axis=0, ddof=0).mean())

    feature_df['std'] = std_list

    return feature_df

# Function to find prominent peaks in the standard deviation
def find_prominent_peaks(feature_df, desired_peak_count=10):
    """Return the positions of the most prominent peaks of the ``std`` column.

    Peaks are searched with ``scipy.signal.find_peaks`` (``distance=3``),
    starting with a minimum prominence of 1.0 and lowering it in steps of 0.1
    until at least ``desired_peak_count`` peaks are found or the threshold
    reaches 0. The prominence finally used is printed.

    Args:
        feature_df: DataFrame (or mapping) with a numeric ``std`` entry.
        desired_peak_count: Maximum number of peaks to return.

    Returns:
        A NumPy array of at most ``desired_peak_count`` positions into
        ``feature_df['std']``, ordered by increasing prominence. Fewer (or no)
        positions are returned when the signal has fewer peaks.
    """
    signal = np.asarray(feature_df['std'], dtype=float)
    min_prominence = 1.0

    while True:
        peaks, properties = find_peaks(
            signal, distance=3, prominence=min_prominence)

        # Stop at prominence 0 as well: below it no further peaks can appear,
        # so a signal with too few peaks would otherwise loop forever.
        if len(peaks) >= desired_peak_count or min_prominence <= 0:
            break
        min_prominence = max(min_prominence - 0.1, 0.0)

    print("Peaks achieved with prominence: ", min_prominence)

    # Keep the requested number of most prominent peaks. The explicit start
    # offset also handles a count of 0, where ``[-0:]`` would keep everything.
    by_prominence = np.argsort(properties['prominences'])
    start = max(len(by_prominence) - desired_peak_count, 0)
    peaks = peaks[by_prominence[start:]]

    return peaks

# Function to display graphs
def display_graphs(df, feature_df, peaks):
    """Plot the ``std`` curve against the frame index and highlight the peaks.

    Args:
        df: DataFrame providing the ``frame_idx`` column (x axis).
        feature_df: DataFrame providing the ``std`` column (y axis).
        peaks: Index values selecting the peak rows in both columns.
    """
    frame_numbers = df['frame_idx']
    std_values = feature_df['std']
    peak_colour = '#ff7f0e'

    # Full curve first, then the peak markers drawn on top of it.
    plt.plot(frame_numbers, std_values, marker='o')
    plt.plot(frame_numbers[peaks], std_values[peaks], "x", ms=10, color=peak_colour, mew=3)

    # One faint dashed vertical guide per peak.
    for peak in peaks:
        plt.axvline(x=frame_numbers[peak], linestyle='--', color=peak_colour, alpha=0.3)

    plt.xlabel('Frame Index')
    plt.ylabel('Standard Deviation')
    plt.title('Standard Deviation of Mean of Angles')
    plt.show()

# Function to display top frames
def display_top_frames(top_frames):
    """Show the given frame images in a grid that is five columns wide.

    The grid grows by one row per five frames (two rows for ten frames).
    Cells without a frame are hidden, and a frame whose image cannot be read
    keeps its title but stays empty.

    Args:
        top_frames: Iterable of image file paths, in display order.
    """
    top_frames = list(top_frames)
    n_cols = 5
    # At least one row so that an empty input still yields a valid figure.
    n_rows = max(1, math.ceil(len(top_frames) / n_cols))

    # squeeze=False keeps ``ax`` two-dimensional even for a single row.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 2.5 * n_rows), squeeze=False)
    fig.suptitle('Top Frames')
    ax = ax.ravel()

    for i, frame_path in enumerate(top_frames):
        ax[i].set_title('Frame ' + str(i + 1))
        ax[i].set_xticks([])
        ax[i].set_yticks([])

        img = cv2.imread(frame_path)
        if img is None:
            # cv2.imread returns None for a missing or unreadable file;
            # converting it would raise, so leave this cell empty instead.
            continue
        # OpenCV loads images as BGR, matplotlib expects RGB.
        ax[i].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    # Hide the grid cells that received no frame.
    for unused in ax[len(top_frames):]:
        unused.axis('off')

    plt.show()


# Command-line entry point: build the features from the CSV, then optionally
# show the std curve and/or the frames at its most prominent peaks.
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Process CSV file and analyze body pose data.')
    parser.add_argument('input_csv', type=str, help='Path to the input CSV file')
    parser.add_argument('--display_graphs', action='store_true', help='Display graphs')
    parser.add_argument('--display_top_frames', action='store_true', help='Display top frames')

    args = parser.parse_args()

    df = process_csv(args.input_csv)
    feature_df = make_feature(df)
    feature_df = calculate_std(feature_df)

    # Both displays use the same peaks, so detect them once (and print the
    # prominence once), and only when a display was requested.
    if args.display_graphs or args.display_top_frames:
        peaks = find_prominent_peaks(feature_df)

        if args.display_graphs:
            display_graphs(df, feature_df, peaks)

        if args.display_top_frames:
            # Order the selected frames by the integer prefix of their file name.
            top_frames = sorted(
                feature_df['file_path'][peaks],
                key=lambda path: int(os.path.basename(path).split('_')[0]))
            display_top_frames(top_frames)
